{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2MDUUcFa7O2u"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-ai-edge/model-explorer/blob/main/example_colabs/custom_data_overlay_demo.ipynb)\n",
    "\n",
    "# Install packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "me12eMftdu4H"
   },
   "outputs": [],
   "source": [
    "# Install tf-nightly & model-explorer.\n",
    "!pip install tf-nightly\n",
    "!pip install --no-deps ai-edge-model-explorer ai-edge-model-explorer-adapter\n",
    "\n",
    "# Install kagglehub (will be used in the next step to download a model)\n",
    "!pip install kagglehub --no-deps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Im5BQ2Qz7na_"
   },
   "source": [
    "# Download MobileNet v3 from Kaggle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xaSxh0rQf4Fj"
   },
   "outputs": [],
   "source": [
    "import kagglehub\n",
    "\n",
    "# This demo uses MobileNet v3, but you can use other models as well\n",
    "path = kagglehub.model_download(\n",
    "    \"google/mobilenet-v3/tfLite/large-075-224-classification\"\n",
    ")\n",
    "model_path = f\"{path}/1.tflite\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mWm4BE_I70hi"
   },
   "source": [
    "# Run the model with test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6cuAg62OvfTK"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Load the TFLite model and allocate tensors.\n",
    "interpreter = tf.lite.Interpreter(model_path=model_path)\n",
    "interpreter.allocate_tensors()\n",
    "\n",
    "# Get input and output tensors.\n",
    "input_details = interpreter.get_input_details()\n",
    "output_details = interpreter.get_output_details()\n",
    "\n",
    "# Generate random input data.\n",
    "for input_detail in input_details:\n",
    "  input_shape = input_detail['shape']\n",
    "  input_data = np.array(\n",
    "      np.random.random_sample(input_shape), dtype=input_detail['dtype']\n",
    "  )\n",
    "  interpreter.set_tensor(input_detail['index'], input_data)\n",
    "\n",
    "# Run the model on random input data.\n",
    "interpreter.invoke()\n",
    "\n",
    "# Examine the output data (optional)\n",
    "for output_detail in output_details:\n",
    "  print(f\"Output for {output_detail['name']}\")\n",
    "  output_data = interpreter.get_tensor(output_detail['index'])\n",
    "  print(output_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HhwvzBUIO9SZ"
   },
   "source": [
    "# Prepare per-op benchmarking data for Model Explorer\n",
    "## Step 1: Run the benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "2FLsz9OsC6w_"
   },
   "outputs": [],
   "source": [
    "!mkdir -p /tmp/data\n",
    "\n",
    "%cd /tmp/data\n",
    "\n",
    "CPU_PROFILING_PROTO_PATH = \"/tmp/data/mv3-cpu-op-profile.pb\"\n",
    "\n",
    "# In this example, we're using profiling data from Android's Benchmarking tools\n",
    "# that has already been mapped (outside of this Colab) to the Model Explorer schema.\n",
    "\n",
    "# You can overlay per-op data of your choice by following the instructions at\n",
    "# https://github.com/google/model-explorer/wiki/2.-User-Guide#custom-node-data\n",
    "\n",
    "%env MODEL_PATH=$model_path\n",
    "%env CPU_PROFILING_PROTO_PATH=$CPU_PROFILING_PROTO_PATH\n",
    "\n",
    "# Download the tflite model benchmark binary.\n",
    "!wget -nc https://storage.googleapis.com/tensorflow-nightly-public/prod/tensorflow/release/lite/tools/nightly/latest/linux_x86-64_benchmark_model\n",
    "!chmod +x /tmp/data/linux_x86-64_benchmark_model\n",
    "\n",
    "# Run the benchmark locally only using CPU kernels with op_profiling enabled.\n",
    "!./linux_x86-64_benchmark_model --graph=$MODEL_PATH --use_xnnpack=false --num_threads=4 --enable_op_profiling=true --op_profiling_output_mode=proto --op_profiling_output_file=$CPU_PROFILING_PROTO_PATH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Generate the per-op profiling JSON using benchmark results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "from collections.abc import Sequence\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "from typing import Any, List\n",
    "\n",
    "from tensorflow.lite.profiling.proto import profiling_info_pb2\n",
    "\n",
    "_PER_OP_LATENCY_JSON_TYPE = \"per_op_latency\"\n",
    "_OP_TYPE_JSON_TYPE = \"op_type\"\n",
    "\n",
    "_MODEL_EXPLORER_JSON_TYPES = [_PER_OP_LATENCY_JSON_TYPE, _OP_TYPE_JSON_TYPE]\n",
    "\n",
    "\n",
    "def get_op_profile_json(\n",
    "    op_profile: profiling_info_pb2.OpProfileData,\n",
    "    model_explorer_json_type: str,\n",
    ") -> dict[str, Any]:\n",
    "  \"\"\"Generates the Model Explorer json for the op profile.\n",
    "\n",
    "  Args:\n",
    "    op_profile: profiling_info_pb2.OpProfileData\n",
    "    model_explorer_json_type: Type of model explorer json to generate.\n",
    "\n",
    "  Returns:\n",
    "    Model explorer json for the op profile.\n",
    "\n",
    "  Raises:\n",
    "    ValueError: If the op profile name is not in the expected format.\n",
    "    ValueError: If the model explorer json type is not supported.\n",
    "  \"\"\"\n",
    "  op_profile_key_re = re.findall(r\":(\\d+)$\", op_profile.name)\n",
    "  if not op_profile_key_re:\n",
    "    raise ValueError(\"Op profile name is not in the expected format.\")\n",
    "  op_profile_key = op_profile_key_re[0]\n",
    "\n",
    "  if model_explorer_json_type == _PER_OP_LATENCY_JSON_TYPE:\n",
    "    return {\n",
    "        op_profile_key: {\n",
    "            \"value\": op_profile.inference_microseconds.avg / 1000.0\n",
    "        }\n",
    "    }\n",
    "  elif model_explorer_json_type == _OP_TYPE_JSON_TYPE:\n",
    "    return {op_profile_key: {\"value\": op_profile.node_type}}\n",
    "  else:\n",
    "    raise ValueError(\n",
    "        \"Unsupported model explorer json type: %s\" % model_explorer_json_type\n",
    "    )\n",
    "\n",
    "\n",
    "def generate_model_explorer_json(\n",
    "    benchmark_profiling_proto_paths: List[str],\n",
    "    output_path: str,\n",
    "    model_explorer_json_type: str,\n",
    ") -> dict[str, Any]:\n",
    "  \"\"\"Generates the Model Explorer json for the benchmark profiling proto.\"\"\"\n",
    "  if not benchmark_profiling_proto_paths:\n",
    "    raise ValueError(\"At least one profiling proto path should be provided.\")\n",
    "\n",
    "  if model_explorer_json_type not in _MODEL_EXPLORER_JSON_TYPES:\n",
    "    raise ValueError(\n",
    "        f\"Unsupported model explorer json type: {model_explorer_json_type}\"\n",
    "    )\n",
    "\n",
    "  output_json = collections.defaultdict(dict)\n",
    "  for proto_path in benchmark_profiling_proto_paths:\n",
    "    if not os.path.isfile(proto_path):\n",
    "      raise ValueError(f\"File {proto_path} does not exist.\")\n",
    "\n",
    "    with open(proto_path, \"rb\") as f:\n",
    "      benchmark_profiling_proto = (\n",
    "          profiling_info_pb2.BenchmarkProfilingData.FromString(f.read())\n",
    "      )\n",
    "      for (\n",
    "          subgraph_profile\n",
    "      ) in benchmark_profiling_proto.runtime_profile.subgraph_profiles:\n",
    "        subgraph_profile_json = {}\n",
    "        for op_profile in subgraph_profile.per_op_profiles:\n",
    "          subgraph_profile_json.update(\n",
    "              get_op_profile_json(op_profile, model_explorer_json_type)\n",
    "          )\n",
    "        output_json[subgraph_profile.subgraph_name][\n",
    "            \"results\"\n",
    "        ] = subgraph_profile_json\n",
    "\n",
    "        if model_explorer_json_type == _PER_OP_LATENCY_JSON_TYPE:\n",
    "          output_json[subgraph_profile.subgraph_name][\"gradient\"] = [\n",
    "              {\"stop\": 0, \"bgColor\": \"green\"},\n",
    "              {\"stop\": 0.33, \"bgColor\": \"yellow\"},\n",
    "              {\"stop\": 0.67, \"bgColor\": \"orange\"},\n",
    "              {\"stop\": 1, \"bgColor\": \"red\"},\n",
    "          ]\n",
    "\n",
    "  if output_path:\n",
    "    with open(output_path, \"w\") as f:\n",
    "      json.dump(output_json, f)\n",
    "  else:\n",
    "    print(json.dumps(output_json, indent=2))\n",
    "\n",
    "\n",
    "CPU_PROFILING_JSON_PATH = \"/tmp/data/mv3-cpu-op-profile.json\"\n",
    "\n",
    "# Generate pure CPU per-op profiling JSON.\n",
    "generate_model_explorer_json(\n",
    "    [CPU_PROFILING_PROTO_PATH],\n",
    "    CPU_PROFILING_JSON_PATH,\n",
    "    _PER_OP_LATENCY_JSON_TYPE,\n",
    ")\n",
    "\n",
    "# Download the XNNPACK per-op profiling JSON from storage.\n",
    "!wget -nc https://storage.googleapis.com/tfweb/model-explorer-demo/mv3-xnnpack-op-profile.json\n",
    "XNNPACK_PROFILING_JSON_PATH = \"/tmp/data/mv3-xnnpack-op-profile.json\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oCgFNFP4WAzE"
   },
   "source": [
    "# Visualize the model with per op latency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6SVyHrthM_hz"
   },
   "outputs": [],
   "source": [
    "import model_explorer\n",
    "\n",
    "config = model_explorer.config()\n",
    "(\n",
    "    config.add_model_from_path(model_path)\n",
    "    .add_node_data_from_path(CPU_PROFILING_JSON_PATH)\n",
    "    .add_node_data_from_path(XNNPACK_PROFILING_JSON_PATH)\n",
    ")\n",
    "\n",
    "model_explorer.visualize_from_config(config)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
